{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Code for **\"Inpainting\"** figures $6$, $8$ and $7$ (top) from the main paper."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import libs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import os\n",
    "# os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "\n",
    "import numpy as np\n",
    "from models.resnet import ResNet\n",
    "from models.unet import UNet\n",
    "from models.skip import skip\n",
    "import torch\n",
    "import torch.optim\n",
    "\n",
    "from utils.inpainting_utils import *\n",
    "\n",
    "torch.backends.cudnn.enabled = True\n",
    "torch.backends.cudnn.benchmark =True\n",
    "# Use the CUDA tensor type when a GPU is available; otherwise fall back to the\n",
    "# CPU tensor type so the notebook still runs (slowly) on machines without CUDA.\n",
    "dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor\n",
    "\n",
    "# PLOT gates the intermediate plots drawn inside the optimization closure.\n",
    "PLOT = True\n",
    "# imsize is passed to get_image() when loading the image and the mask.\n",
    "imsize = -1\n",
    "# dim_div_by is passed to crop_image() in the center-crop cell.\n",
    "dim_div_by = 64"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Choose figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Fig 6\n",
    "# img_path  = 'data/inpainting/vase.png'\n",
    "# mask_path = 'data/inpainting/vase_mask.png'\n",
    "\n",
    "## Fig 8\n",
    "# img_path  = 'data/inpainting/library.png'\n",
    "# mask_path = 'data/inpainting/library_mask.png'\n",
    "\n",
    "## Fig 7 (top)\n",
    "img_path  = 'data/inpainting/kate.png'\n",
    "mask_path = 'data/inpainting/kate_mask.png'\n",
    "\n",
    "# Another text inpainting example\n",
    "# img_path  = 'data/inpainting/peppers.png'\n",
    "# mask_path = 'data/inpainting/peppers_mask.png'\n",
    "\n",
    "NET_TYPE = 'skip_depth6' # one of skip_depth6|skip_depth4|skip_depth2|UNET|ResNet; only read by the library.png settings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load image and mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the image and its mask from the paths chosen above; each call is unpacked\n",
    "# into a (PIL image, numpy array) pair. The numpy arrays are recomputed from the\n",
    "# cropped PIL images in the center-crop cell below.\n",
    "img_pil, img_np = get_image(img_path, imsize)\n",
    "img_mask_pil, img_mask_np = get_image(mask_path, imsize)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Center crop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Crop mask and image with the same dim_div_by setting so they stay aligned\n",
    "# (presumably crop_image makes both sides divisible by dim_div_by -- confirm in utils).\n",
    "img_mask_pil = crop_image(img_mask_pil, dim_div_by)\n",
    "img_pil      = crop_image(img_pil,      dim_div_by)\n",
    "\n",
    "# Rebuild the numpy arrays so they match the cropped PIL images.\n",
    "img_np      = pil_to_np(img_pil)\n",
    "img_mask_np = pil_to_np(img_mask_pil)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Mask converted to a torch tensor of the working dtype.\n",
    "# NOTE(review): later cells in this notebook use `mask_var` rather than this variable.\n",
    "img_mask_var = np_to_torch(img_mask_np).type(dtype)\n",
    "\n",
    "# Plot the image, the mask, and their elementwise product (the masked image).\n",
    "plot_image_grid([img_np, img_mask_np, img_mask_np*img_np], 3,11);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Padding mode passed to skip() when building the network.\n",
    "pad = 'reflection' # 'zero'\n",
    "# Passed to get_params() to select what is optimized.\n",
    "OPT_OVER = 'net'\n",
    "# Optimizer name passed to optimize().\n",
    "OPTIMIZER = 'adam'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-figure hyperparameters and network, selected by the image file name.\n",
    "#   INPUT / input_depth : passed to get_noise() to build the network input\n",
    "#   LR / num_iter       : passed to optimize()\n",
    "#   param_noise         : enables weight perturbation inside the closure\n",
    "#   reg_noise_std       : scale of the noise added to the input inside the closure\n",
    "#   show_every / figsize: plotting interval and plot size used inside the closure\n",
    "if 'vase.png' in img_path:\n",
    "    INPUT = 'meshgrid'\n",
    "    input_depth = 2\n",
    "    LR = 0.01\n",
    "    num_iter = 5001\n",
    "    param_noise = False\n",
    "    show_every = 50\n",
    "    figsize = 5\n",
    "    reg_noise_std = 0.03\n",
    "\n",
    "    net = skip(input_depth, img_np.shape[0],\n",
    "               num_channels_down = [128] * 5,\n",
    "               num_channels_up   = [128] * 5,\n",
    "               num_channels_skip = [0] * 5,\n",
    "               upsample_mode='nearest', filter_skip_size=1, filter_size_up=3, filter_size_down=3,\n",
    "               need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU').type(dtype)\n",
    "\n",
    "elif ('kate.png' in img_path) or ('peppers.png' in img_path):\n",
    "    # Same params and net as in super-resolution and denoising\n",
    "    INPUT = 'noise'\n",
    "    input_depth = 32\n",
    "    LR = 0.01\n",
    "    num_iter = 6001\n",
    "    param_noise = False\n",
    "    show_every = 50\n",
    "    figsize = 5\n",
    "    reg_noise_std = 0.03\n",
    "\n",
    "    net = skip(input_depth, img_np.shape[0],\n",
    "               num_channels_down = [128] * 5,\n",
    "               num_channels_up =   [128] * 5,\n",
    "               num_channels_skip =    [128] * 5,\n",
    "               filter_size_up = 3, filter_size_down = 3,\n",
    "               upsample_mode='nearest', filter_skip_size=1,\n",
    "               need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU').type(dtype)\n",
    "\n",
    "elif 'library.png' in img_path:\n",
    "\n",
    "    INPUT = 'noise'\n",
    "    input_depth = 1\n",
    "\n",
    "    num_iter = 3001\n",
    "    show_every = 50\n",
    "    figsize = 8\n",
    "    reg_noise_std = 0.00\n",
    "    param_noise = True\n",
    "\n",
    "    # This figure compares architectures; NET_TYPE picks which one to build.\n",
    "    if 'skip' in NET_TYPE:\n",
    "\n",
    "        # The trailing digit of NET_TYPE is the number of scales to keep.\n",
    "        depth = int(NET_TYPE[-1])\n",
    "        net = skip(input_depth, img_np.shape[0],\n",
    "               num_channels_down = [16, 32, 64, 128, 128, 128][:depth],\n",
    "               num_channels_up =   [16, 32, 64, 128, 128, 128][:depth],\n",
    "               num_channels_skip =    [0, 0, 0, 0, 0, 0][:depth],\n",
    "               filter_size_up = 3,filter_size_down = 5,  filter_skip_size=1,\n",
    "               upsample_mode='nearest', # downsample_mode='avg',\n",
    "               need1x1_up=False,\n",
    "               need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU').type(dtype)\n",
    "\n",
    "        LR = 0.01\n",
    "\n",
    "    elif NET_TYPE == 'UNET':\n",
    "\n",
    "        # Output channels follow the image, consistent with the skip and ResNet branches.\n",
    "        net = UNet(num_input_channels=input_depth, num_output_channels=img_np.shape[0],\n",
    "                   feature_scale=8, more_layers=1,\n",
    "                   concat_x=False, upsample_mode='deconv',\n",
    "                   pad='zero', norm_layer=torch.nn.InstanceNorm2d, need_sigmoid=True, need_bias=True)\n",
    "\n",
    "        LR = 0.001\n",
    "        param_noise = False\n",
    "\n",
    "    elif NET_TYPE == 'ResNet':\n",
    "\n",
    "        net = ResNet(input_depth, img_np.shape[0], 8, 32, need_sigmoid=True, act_fun='LeakyReLU')\n",
    "\n",
    "        LR = 0.001\n",
    "        param_noise = False\n",
    "\n",
    "    else:\n",
    "        raise ValueError('Unknown NET_TYPE: %r (expected skip_depth<N>, UNET or ResNet)' % NET_TYPE)\n",
    "else:\n",
    "    raise ValueError('No settings defined for image: %r' % img_path)\n",
    "\n",
    "# Move the network to the working dtype and build the fixed network input\n",
    "# with the same spatial size as the image.\n",
    "net = net.type(dtype)\n",
    "net_input = get_noise(input_depth, INPUT, img_np.shape[1:]).type(dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute number of parameters\n",
    "s  = sum(np.prod(list(p.size())) for p in net.parameters())\n",
    "print ('Number of params: %d' % s)\n",
    "\n",
    "# Loss\n",
    "mse = torch.nn.MSELoss().type(dtype)\n",
    "\n",
    "img_var = np_to_torch(img_np).type(dtype)\n",
    "mask_var = np_to_torch(img_mask_np).type(dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Main loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "i = 0\n",
    "def closure():\n",
    "    \"\"\"Run one forward/backward pass of the masked reconstruction and return the loss.\n",
    "\n",
    "    Reads the globals net, net_input_saved, noise, mask_var, img_var and the\n",
    "    settings param_noise, reg_noise_std, show_every, figsize and PLOT; increments\n",
    "    the global iteration counter `i` on every call.\n",
    "    \"\"\"\n",
    "    global i\n",
    "\n",
    "    # Perturb every 4-D parameter tensor in place. The update has to be in place\n",
    "    # (under no_grad): rebinding the loop variable would leave the network weights\n",
    "    # untouched and the noise would never be applied.\n",
    "    if param_noise:\n",
    "        with torch.no_grad():\n",
    "            for param in net.parameters():\n",
    "                if param.dim() == 4:\n",
    "                    param.add_(torch.randn_like(param) * param.std() / 50)\n",
    "\n",
    "    # Optionally jitter the saved input with freshly drawn Gaussian noise.\n",
    "    net_input = net_input_saved\n",
    "    if reg_noise_std > 0:\n",
    "        net_input = net_input_saved + (noise.normal_() * reg_noise_std)\n",
    "\n",
    "    out = net(net_input)\n",
    "\n",
    "    # Both terms are multiplied by the mask, so pixels where the mask is zero\n",
    "    # contribute nothing to the loss.\n",
    "    total_loss = mse(out * mask_var, img_var * mask_var)\n",
    "    total_loss.backward()\n",
    "\n",
    "    print ('Iteration %05d    Loss %f' % (i, total_loss.item()), '\\r', end='')\n",
    "    if PLOT and i % show_every == 0:\n",
    "        out_np = torch_to_np(out)\n",
    "        plot_image_grid([np.clip(out_np, 0, 1)], factor=figsize, nrow=1)\n",
    "\n",
    "    i += 1\n",
    "\n",
    "    return total_loss\n",
    "\n",
    "# Keep a pristine copy of the input plus a scratch buffer that is refilled\n",
    "# with noise on each closure call.\n",
    "net_input_saved = net_input.detach().clone()\n",
    "noise = net_input.detach().clone()\n",
    "\n",
    "p = get_params(OPT_OVER, net, net_input)\n",
    "optimize(OPTIMIZER, p, closure, LR, num_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final reconstruction from the original (un-jittered) network input.\n",
    "# no_grad avoids building an autograd graph for this inference-only pass.\n",
    "with torch.no_grad():\n",
    "    out_np = torch_to_np(net(net_input))\n",
    "plot_image_grid([out_np], factor=5);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
